/*
 * Copyright (c) 2020 - present, Inspur Genersoft Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.iec.edp.caf.context.memory;

import io.iec.edp.caf.context.core.listener.BizContextListener;
import io.iec.edp.caf.core.context.*;
import io.iec.edp.caf.core.session.CurrentSessionManager;
import io.iec.edp.caf.core.session.Session;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * In-memory {@link BizContextManager} that keeps every business context in a process-local map.
 *
 * <p>Contexts created while a session is current (as reported by the injected
 * {@link CurrentSessionManager}) are additionally indexed by that session's id so they can be
 * released together via {@link #destroyBySession(String)}. Expiration, serializer-type and
 * prolongation arguments are accepted for interface compatibility but ignored: contexts live
 * until they are explicitly destroyed.
 *
 * @author yisiqi
 * @since 1.0.0
 */
@Slf4j
public class InMemoryBizContextManager implements BizContextManager {

    private final CurrentSessionManager currentWebSessionManager;

    /** Context type -> builder registered for exactly that type. */
    private final Map<Class<? extends BizContext>, BizContextBuilder> buildersMap = new ConcurrentHashMap<>();

    /** Context id -> live context. */
    private final Map<String, BizContext> contextRepo = new ConcurrentHashMap<>();

    /**
     * Session id -> ids of the contexts created while that session was current.
     * Concurrent like the other maps (presumably the manager is shared across request threads);
     * the per-session lists are synchronized wrappers.
     */
    private final Map<String, List<String>> sessionMap = new ConcurrentHashMap<>();

    /**
     * Canonical context type name -> registered listeners. A registration id encodes the slot
     * index, so unregistered slots are set to {@code null} rather than removed: removing would
     * shift later slots and make their registration ids refer to the wrong listener.
     */
    private final Map<String, List<BizContextListener>> listenersMap = new ConcurrentHashMap<>();


    /**
     * Creates a manager.
     *
     * @param manager  source of the current session, used to group new contexts by session id
     * @param builders builders indexed by their context type; may be {@code null} or empty
     */
    @SuppressWarnings("unchecked")
    public InMemoryBizContextManager(CurrentSessionManager manager, List<BizContextBuilder> builders) {
        this.currentWebSessionManager = manager;
        if (builders != null && !builders.isEmpty()) {
            builders.forEach((b) -> buildersMap.put(b.getContextType(), b));
        }
    }


    /**
     * Creates and stores a context without a parent.
     *
     * @throws BizContextBuilderNotFoundException if no builder is registered for {@code contextType}
     */
    @Override
    @SneakyThrows(BizContextNotFoundException.class)
    public <C extends BizContext> C createRootContext(Class<C> contextType) throws BizContextBuilderNotFoundException {
        // The parent check (the only source of BizContextNotFoundException) is skipped for roots.
        return buildNewContext(contextType, null, true);
    }

    /**
     * Creates and stores a context whose parent is {@code parent}.
     *
     * @throws BizContextNotFoundException        if {@code parent} is {@code null} or not stored here
     * @throws BizContextBuilderNotFoundException if no builder is registered for {@code contextType}
     */
    @Override
    public <C extends BizContext> C createContextUnder(Class<C> contextType, BizContext parent) throws BizContextNotFoundException, BizContextBuilderNotFoundException {
        if (parent == null) {
            throw new BizContextNotFoundException();
        }
        return createContextUnder(contextType, parent.getId());
    }

    /** Same as {@link #createRootContext(Class)}; expiration and serializer type are ignored in memory. */
    @Override
    public <C extends BizContext> C createRootContext(Class<C> contextType, long expiration, TimeUnit unit, String serializerType) throws BizContextBuilderNotFoundException {
        return createRootContext(contextType);
    }

    /** Same as {@link #createContextUnder(Class, String)}; expiration and serializer type are ignored in memory. */
    @Override
    public <C extends BizContext> C createContextUnder(Class<C> contextType, String parentId, long expiration, TimeUnit unit, String serializerType) throws BizContextNotFoundException, BizContextBuilderNotFoundException {
        return createContextUnder(contextType, parentId);
    }

    /** Same as {@link #createContextUnder(Class, BizContext)}; expiration and serializer type are ignored in memory. */
    @Override
    public <C extends BizContext> C createContextUnder(Class<C> contextType, BizContext parent, long expiration, TimeUnit unit, String serializerType) throws BizContextNotFoundException, BizContextBuilderNotFoundException {
        return createContextUnder(contextType, parent);
    }

    /**
     * Creates and stores a context under the stored context {@code parentId}.
     *
     * @throws BizContextNotFoundException        if {@code parentId} is {@code null} or unknown
     * @throws BizContextBuilderNotFoundException if no builder is registered for {@code contextType}
     */
    @Override
    public <C extends BizContext> C createContextUnder(Class<C> contextType, String parentId) throws BizContextNotFoundException, BizContextBuilderNotFoundException {
        BizContext parent = lookup(parentId);
        if (parent == null) {
            throw new BizContextNotFoundException();
        }
        return buildNewContext(contextType, parent, false);
    }

    /**
     * Builds a context with the builder registered for {@code contextType} and stores it.
     *
     * @param contextType     exact type the builder was registered under
     * @param parent          parent session/context passed to the builder; {@code null} for roots
     * @param skipParentCheck when {@code false}, {@code parent} must be a stored context
     * @return the stored context, or {@code null} if the builder produced nothing or the wrong type
     */
    @SuppressWarnings("unchecked")
    private <C extends BizContext> C buildNewContext(
            Class<C> contextType, Session parent, boolean skipParentCheck
    ) throws BizContextBuilderNotFoundException, BizContextNotFoundException {
        if (!skipParentCheck && (parent == null || contextRepo.get(parent.getId()) == null)) {
            throw new BizContextNotFoundException();
        }
        // Builders are looked up by the exact context type they declared.
        BizContextBuilder builder = buildersMap.get(contextType);
        if (builder == null) {
            throw new BizContextBuilderNotFoundException();
        }

        try {
            // Class.cast performs a real runtime check. An unchecked (C) cast is erased, never
            // throws here, and would let a mistyped context be saved and fail later at the caller.
            C context = contextType.cast(builder.build(parent));
            if (context == null) {
                log.error("Context builder {} returned null for type: {}.",
                        builder.getClass().getCanonicalName(), contextType.getCanonicalName());
                return null;
            }
            this.save(context);
            return context;
        } catch (ClassCastException ex) {
            log.error("Failed to build context of type: {}. Please check the context builder: {}.",
                    contextType.getCanonicalName(), builder.getClass().getCanonicalName(), ex);
            return null;
        }
    }


    /**
     * Returns the stored context {@code contextId}, cast to the requested type (unchecked).
     *
     * @throws BizContextNotFoundException if no such context is stored
     */
    @Override
    @SuppressWarnings("unchecked")
    public <C extends BizContext> C fetch(Class<C> contextType, String contextId) throws BizContextNotFoundException {
        return (C) fetch(contextId);
    }

    /**
     * Returns the stored context {@code contextId}.
     *
     * @throws BizContextNotFoundException if {@code contextId} is {@code null} or not stored
     */
    @Override
    public BizContext fetch(String contextId) throws BizContextNotFoundException {
        BizContext context = lookup(contextId);
        if (context == null) {
            throw new BizContextNotFoundException();
        }
        return context;
    }

    /** Always returns {@link BizContext}; type deduction is not implemented for the in-memory store. */
    @Override
    public Class<? extends BizContext> deduce(String contextId) {
        // TODO: implement type deduction
        return BizContext.class;
    }

    /**
     * Marks the context's items as persisted. The context object itself is already the stored
     * instance, so nothing has to be written back.
     *
     * @param context the context to update; ignored when {@code null}
     */
    @Override
    public void update(BizContext context) {
        if (context != null) {
            context.setItemsChanged(false);
        }
    }

    /**
     * Registers {@code listener} for {@code contextType}.
     *
     * @return an opaque id (Base64 of {@code "<canonical type name>:<slot>"}) for
     *         {@link #unRegister(String)}
     */
    @Override
    public <C extends BizContext> String register(Class<C> contextType, BizContextListener<C> listener) {
        String contextTypeName = contextType.getCanonicalName();
        List<BizContextListener> listeners = listenersMap.computeIfAbsent(contextTypeName, k -> new ArrayList<>());
        int slot;
        // Append and read the slot atomically so concurrent registrations get distinct ids.
        synchronized (listeners) {
            listeners.add(listener);
            slot = listeners.size() - 1;
        }
        return Base64.getEncoder().encodeToString((contextTypeName + ":" + slot).getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Unregisters the listener identified by an id returned from {@link #register}.
     * Unknown, stale or malformed ids are ignored.
     */
    @Override
    public void unRegister(String registrationId) {
        if (registrationId == null || registrationId.isEmpty()) {
            return;
        }
        String decoded;
        try {
            decoded = new String(Base64.getDecoder().decode(registrationId), StandardCharsets.UTF_8);
        } catch (IllegalArgumentException ex) {
            log.warn("Ignoring malformed listener registration id: {}", registrationId);
            return;
        }
        String[] decodedRegistration = decoded.split(":");
        if (decodedRegistration.length != 2) {
            return;
        }

        String contextTypeName = decodedRegistration[0];
        int index;
        try {
            index = Integer.parseInt(decodedRegistration[1]);
        } catch (NumberFormatException ex) {
            log.warn("Ignoring malformed listener registration id: {}", registrationId);
            return;
        }
        List<BizContextListener> listeners = listenersMap.get(contextTypeName);
        if (listeners == null) {
            return;
        }
        synchronized (listeners) {
            if (index >= 0 && index < listeners.size()) {
                // Tombstone the slot so the indices of other registration ids stay valid.
                listeners.set(index, null);
            }
        }
    }

    /** Unregisters the first registration of {@code listener} for {@code contextType}, if any. */
    @Override
    public <C extends BizContext> void unRegister(Class<C> contextType, BizContextListener<C> listener) {
        if (listener == null) {
            return;
        }
        List<BizContextListener> listeners = listenersMap.get(contextType.getCanonicalName());
        if (listeners == null) {
            return;
        }
        synchronized (listeners) {
            int slot = listeners.indexOf(listener);
            if (slot >= 0) {
                // Tombstone instead of remove so registration ids of later slots remain valid.
                listeners.set(slot, null);
            }
        }
    }

    /**
     * Removes the context {@code contextId} (its children are kept).
     *
     * @throws BizContextNotFoundException if {@code contextId} is {@code null} or not stored
     */
    @Override
    public void destroy(String contextId) throws BizContextNotFoundException {
        // Single atomic remove instead of fetch-then-remove.
        if (contextId == null || contextRepo.remove(contextId) == null) {
            throw new BizContextNotFoundException();
        }
    }

    /**
     * Destroys every context recorded for the given session and forgets the session.
     *
     * @param sessionId id of the session whose contexts should be destroyed; ignored when
     *                  {@code null}, empty or unknown
     */
    @Override
    public void destroyBySession(String sessionId) {
        if (sessionId == null || sessionId.isEmpty()) {
            return;
        }
        // Claim the whole group atomically so concurrent calls do not process it twice.
        List<String> contextIds = sessionMap.remove(sessionId);
        if (contextIds == null) {
            return;
        }
        List<String> snapshot;
        synchronized (contextIds) {
            snapshot = new ArrayList<>(contextIds);
        }
        for (String item : snapshot) {
            try {
                // Every id recorded for the session is destroyed on its own; no cascading.
                this.destroy(item, false);
            } catch (BizContextNotFoundException ex) {
                // Expected when the context was already destroyed individually:
                // destroy(String) does not prune sessionMap.
                log.debug("Context {} of session {} was already destroyed.", item, sessionId);
            } catch (Exception ex) {
                log.error(ex.getMessage(), ex);
            }
        }
    }

    /** Destroys {@code context} (non-cascading); see {@link #destroy(String)}. */
    @Override
    public void destroy(BizContext context) throws BizContextNotFoundException {
        destroy(context.getId());
    }

    /**
     * Destroys {@code contextId} and, when {@code cascade} is {@code TRUE}, all of its
     * descendants (children, grandchildren, ...). A {@code null} cascade flag means no cascade.
     *
     * @throws BizContextNotFoundException if {@code contextId} is not stored
     */
    @Override
    public void destroy(String contextId, Boolean cascade) throws BizContextNotFoundException {
        destroy(contextId);
        if (Boolean.TRUE.equals(cascade)) {
            removeDescendants(contextId);
        }
    }

    /** Destroys {@code context}, cascading as described in {@link #destroy(String, Boolean)}. */
    @Override
    public void destroy(BizContext context, Boolean cascade) throws BizContextNotFoundException {
        destroy(context.getId(), cascade);
    }

    /**
     * Removes every transitive child of {@code rootId} from the repository. Root contexts have no
     * parent, so the parent reference is null-checked before comparing ids.
     */
    private void removeDescendants(String rootId) {
        Deque<String> pending = new ArrayDeque<>();
        pending.push(rootId);
        while (!pending.isEmpty()) {
            String parentId = pending.pop();
            List<String> children = new ArrayList<>();
            contextRepo.forEach((id, ctx) -> {
                if (ctx.getParent() != null && parentId.equals(ctx.getParent().getId())) {
                    children.add(id);
                }
            });
            for (String childId : children) {
                // Only descend into ids actually removed here; this also terminates on cycles.
                if (contextRepo.remove(childId) != null) {
                    pending.push(childId);
                }
            }
        }
    }

    /**
     * Reports whether a context is no longer available.
     *
     * @param contextId id of the context
     * @return {@code true} if {@code contextId} is {@code null} or not stored
     */
    @Override
    public boolean isExpired(String contextId) {
        return lookup(contextId) == null;
    }

    /** No-op: in-memory contexts have no expiration to extend. */
    @Override
    public void prolong(String contextId) {
    }

    /** Returns the stored context, or {@code null} when absent or when {@code contextId} is {@code null}. */
    private BizContext lookup(String contextId) {
        // ConcurrentHashMap rejects null keys, so guard instead of letting get() throw NPE.
        return contextId == null ? null : contextRepo.get(contextId);
    }

    /**
     * Stores a context and records it under the current session's id, if there is one.
     *
     * @param context the context to store
     */
    private void save(BizContext context) {
        contextRepo.put(context.getId(), context);
        if (currentWebSessionManager == null || currentWebSessionManager.getCurrentSession() == null) {
            return;
        }
        String sessionId = currentWebSessionManager.getCurrentSession().getId();
        if (sessionId == null) {
            // Null ids cannot be ConcurrentHashMap keys and could never be passed to destroyBySession.
            return;
        }
        sessionMap.computeIfAbsent(sessionId, k -> Collections.synchronizedList(new ArrayList<String>()))
                .add(context.getId());
    }
}
